{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using MobileNet for our Monkey Classifier\n",
    "\n",
    "### Loading the MobileNet Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Freeze all of the pre-trained MobileNet layers; only the new fully connected head added on top later will be trained"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 InputLayer False\n",
      "1 ZeroPadding2D False\n",
      "2 Conv2D False\n",
      "3 BatchNormalization False\n",
      "4 ReLU False\n",
      "5 DepthwiseConv2D False\n",
      "6 BatchNormalization False\n",
      "7 ReLU False\n",
      "8 Conv2D False\n",
      "9 BatchNormalization False\n",
      "10 ReLU False\n",
      "11 ZeroPadding2D False\n",
      "12 DepthwiseConv2D False\n",
      "13 BatchNormalization False\n",
      "14 ReLU False\n",
      "15 Conv2D False\n",
      "16 BatchNormalization False\n",
      "17 ReLU False\n",
      "18 DepthwiseConv2D False\n",
      "19 BatchNormalization False\n",
      "20 ReLU False\n",
      "21 Conv2D False\n",
      "22 BatchNormalization False\n",
      "23 ReLU False\n",
      "24 ZeroPadding2D False\n",
      "25 DepthwiseConv2D False\n",
      "26 BatchNormalization False\n",
      "27 ReLU False\n",
      "28 Conv2D False\n",
      "29 BatchNormalization False\n",
      "30 ReLU False\n",
      "31 DepthwiseConv2D False\n",
      "32 BatchNormalization False\n",
      "33 ReLU False\n",
      "34 Conv2D False\n",
      "35 BatchNormalization False\n",
      "36 ReLU False\n",
      "37 ZeroPadding2D False\n",
      "38 DepthwiseConv2D False\n",
      "39 BatchNormalization False\n",
      "40 ReLU False\n",
      "41 Conv2D False\n",
      "42 BatchNormalization False\n",
      "43 ReLU False\n",
      "44 DepthwiseConv2D False\n",
      "45 BatchNormalization False\n",
      "46 ReLU False\n",
      "47 Conv2D False\n",
      "48 BatchNormalization False\n",
      "49 ReLU False\n",
      "50 DepthwiseConv2D False\n",
      "51 BatchNormalization False\n",
      "52 ReLU False\n",
      "53 Conv2D False\n",
      "54 BatchNormalization False\n",
      "55 ReLU False\n",
      "56 DepthwiseConv2D False\n",
      "57 BatchNormalization False\n",
      "58 ReLU False\n",
      "59 Conv2D False\n",
      "60 BatchNormalization False\n",
      "61 ReLU False\n",
      "62 DepthwiseConv2D False\n",
      "63 BatchNormalization False\n",
      "64 ReLU False\n",
      "65 Conv2D False\n",
      "66 BatchNormalization False\n",
      "67 ReLU False\n",
      "68 DepthwiseConv2D False\n",
      "69 BatchNormalization False\n",
      "70 ReLU False\n",
      "71 Conv2D False\n",
      "72 BatchNormalization False\n",
      "73 ReLU False\n",
      "74 ZeroPadding2D False\n",
      "75 DepthwiseConv2D False\n",
      "76 BatchNormalization False\n",
      "77 ReLU False\n",
      "78 Conv2D False\n",
      "79 BatchNormalization False\n",
      "80 ReLU False\n",
      "81 DepthwiseConv2D False\n",
      "82 BatchNormalization False\n",
      "83 ReLU False\n",
      "84 Conv2D False\n",
      "85 BatchNormalization False\n",
      "86 ReLU False\n"
     ]
    }
   ],
   "source": [
    "from keras.applications import MobileNet\n",
    "\n",
    "# MobileNet was designed for 224 x 224 pixel input images\n",
    "img_rows = 224\n",
    "img_cols = 224\n",
    "\n",
    "# Load the ImageNet-trained convolutional base only (include_top=False\n",
    "# leaves out the original fully connected classifier).\n",
    "# NOTE: this rebinds the name `MobileNet` from the class to the model\n",
    "# instance, which is what the later cells use.\n",
    "MobileNet = MobileNet(\n",
    "    input_shape=(img_rows, img_cols, 3),\n",
    "    include_top=False,\n",
    "    weights='imagenet',\n",
    ")\n",
    "\n",
    "# Freeze every layer of the base (layers are trainable by default) and\n",
    "# list each one with its index, class name and trainable flag\n",
    "for index, layer in enumerate(MobileNet.layers):\n",
    "    layer.trainable = False\n",
    "    print(\"{} {}\".format(index, type(layer).__name__), layer.trainable)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Let's make a function that returns our FC Head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def addTopModelMobileNet(bottom_model, num_classes):\n",
    "    \"\"\"Build the fully connected head that sits on top of ``bottom_model``.\n",
    "\n",
    "    Global average pooling is applied to ``bottom_model.output``, followed\n",
    "    by Dense layers of 1024, 1024 and 512 ReLU units and a final\n",
    "    ``num_classes``-unit softmax layer. The result of that last layer call\n",
    "    is returned (the head's output, not a Model).\n",
    "    \"\"\"\n",
    "    head = GlobalAveragePooling2D()(bottom_model.output)\n",
    "    for units in (1024, 1024, 512):\n",
    "        head = Dense(units, activation='relu')(head)\n",
    "    return Dense(num_classes, activation='softmax')(head)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Let's add our FC Head back onto MobileNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "input_2 (InputLayer)         (None, 224, 224, 3)       0         \n",
      "_________________________________________________________________\n",
      "conv1_pad (ZeroPadding2D)    (None, 225, 225, 3)       0         \n",
      "_________________________________________________________________\n",
      "conv1 (Conv2D)               (None, 112, 112, 32)      864       \n",
      "_________________________________________________________________\n",
      "conv1_bn (BatchNormalization (None, 112, 112, 32)      128       \n",
      "_________________________________________________________________\n",
      "conv1_relu (ReLU)            (None, 112, 112, 32)      0         \n",
      "_________________________________________________________________\n",
      "conv_dw_1 (DepthwiseConv2D)  (None, 112, 112, 32)      288       \n",
      "_________________________________________________________________\n",
      "conv_dw_1_bn (BatchNormaliza (None, 112, 112, 32)      128       \n",
      "_________________________________________________________________\n",
      "conv_dw_1_relu (ReLU)        (None, 112, 112, 32)      0         \n",
      "_________________________________________________________________\n",
      "conv_pw_1 (Conv2D)           (None, 112, 112, 64)      2048      \n",
      "_________________________________________________________________\n",
      "conv_pw_1_bn (BatchNormaliza (None, 112, 112, 64)      256       \n",
      "_________________________________________________________________\n",
      "conv_pw_1_relu (ReLU)        (None, 112, 112, 64)      0         \n",
      "_________________________________________________________________\n",
      "conv_pad_2 (ZeroPadding2D)   (None, 113, 113, 64)      0         \n",
      "_________________________________________________________________\n",
      "conv_dw_2 (DepthwiseConv2D)  (None, 56, 56, 64)        576       \n",
      "_________________________________________________________________\n",
      "conv_dw_2_bn (BatchNormaliza (None, 56, 56, 64)        256       \n",
      "_________________________________________________________________\n",
      "conv_dw_2_relu (ReLU)        (None, 56, 56, 64)        0         \n",
      "_________________________________________________________________\n",
      "conv_pw_2 (Conv2D)           (None, 56, 56, 128)       8192      \n",
      "_________________________________________________________________\n",
      "conv_pw_2_bn (BatchNormaliza (None, 56, 56, 128)       512       \n",
      "_________________________________________________________________\n",
      "conv_pw_2_relu (ReLU)        (None, 56, 56, 128)       0         \n",
      "_________________________________________________________________\n",
      "conv_dw_3 (DepthwiseConv2D)  (None, 56, 56, 128)       1152      \n",
      "_________________________________________________________________\n",
      "conv_dw_3_bn (BatchNormaliza (None, 56, 56, 128)       512       \n",
      "_________________________________________________________________\n",
      "conv_dw_3_relu (ReLU)        (None, 56, 56, 128)       0         \n",
      "_________________________________________________________________\n",
      "conv_pw_3 (Conv2D)           (None, 56, 56, 128)       16384     \n",
      "_________________________________________________________________\n",
      "conv_pw_3_bn (BatchNormaliza (None, 56, 56, 128)       512       \n",
      "_________________________________________________________________\n",
      "conv_pw_3_relu (ReLU)        (None, 56, 56, 128)       0         \n",
      "_________________________________________________________________\n",
      "conv_pad_4 (ZeroPadding2D)   (None, 57, 57, 128)       0         \n",
      "_________________________________________________________________\n",
      "conv_dw_4 (DepthwiseConv2D)  (None, 28, 28, 128)       1152      \n",
      "_________________________________________________________________\n",
      "conv_dw_4_bn (BatchNormaliza (None, 28, 28, 128)       512       \n",
      "_________________________________________________________________\n",
      "conv_dw_4_relu (ReLU)        (None, 28, 28, 128)       0         \n",
      "_________________________________________________________________\n",
      "conv_pw_4 (Conv2D)           (None, 28, 28, 256)       32768     \n",
      "_________________________________________________________________\n",
      "conv_pw_4_bn (BatchNormaliza (None, 28, 28, 256)       1024      \n",
      "_________________________________________________________________\n",
      "conv_pw_4_relu (ReLU)        (None, 28, 28, 256)       0         \n",
      "_________________________________________________________________\n",
      "conv_dw_5 (DepthwiseConv2D)  (None, 28, 28, 256)       2304      \n",
      "_________________________________________________________________\n",
      "conv_dw_5_bn (BatchNormaliza (None, 28, 28, 256)       1024      \n",
      "_________________________________________________________________\n",
      "conv_dw_5_relu (ReLU)        (None, 28, 28, 256)       0         \n",
      "_________________________________________________________________\n",
      "conv_pw_5 (Conv2D)           (None, 28, 28, 256)       65536     \n",
      "_________________________________________________________________\n",
      "conv_pw_5_bn (BatchNormaliza (None, 28, 28, 256)       1024      \n",
      "_________________________________________________________________\n",
      "conv_pw_5_relu (ReLU)        (None, 28, 28, 256)       0         \n",
      "_________________________________________________________________\n",
      "conv_pad_6 (ZeroPadding2D)   (None, 29, 29, 256)       0         \n",
      "_________________________________________________________________\n",
      "conv_dw_6 (DepthwiseConv2D)  (None, 14, 14, 256)       2304      \n",
      "_________________________________________________________________\n",
      "conv_dw_6_bn (BatchNormaliza (None, 14, 14, 256)       1024      \n",
      "_________________________________________________________________\n",
      "conv_dw_6_relu (ReLU)        (None, 14, 14, 256)       0         \n",
      "_________________________________________________________________\n",
      "conv_pw_6 (Conv2D)           (None, 14, 14, 512)       131072    \n",
      "_________________________________________________________________\n",
      "conv_pw_6_bn (BatchNormaliza (None, 14, 14, 512)       2048      \n",
      "_________________________________________________________________\n",
      "conv_pw_6_relu (ReLU)        (None, 14, 14, 512)       0         \n",
      "_________________________________________________________________\n",
      "conv_dw_7 (DepthwiseConv2D)  (None, 14, 14, 512)       4608      \n",
      "_________________________________________________________________\n",
      "conv_dw_7_bn (BatchNormaliza (None, 14, 14, 512)       2048      \n",
      "_________________________________________________________________\n",
      "conv_dw_7_relu (ReLU)        (None, 14, 14, 512)       0         \n",
      "_________________________________________________________________\n",
      "conv_pw_7 (Conv2D)           (None, 14, 14, 512)       262144    \n",
      "_________________________________________________________________\n",
      "conv_pw_7_bn (BatchNormaliza (None, 14, 14, 512)       2048      \n",
      "_________________________________________________________________\n",
      "conv_pw_7_relu (ReLU)        (None, 14, 14, 512)       0         \n",
      "_________________________________________________________________\n",
      "conv_dw_8 (DepthwiseConv2D)  (None, 14, 14, 512)       4608      \n",
      "_________________________________________________________________\n",
      "conv_dw_8_bn (BatchNormaliza (None, 14, 14, 512)       2048      \n",
      "_________________________________________________________________\n",
      "conv_dw_8_relu (ReLU)        (None, 14, 14, 512)       0         \n",
      "_________________________________________________________________\n",
      "conv_pw_8 (Conv2D)           (None, 14, 14, 512)       262144    \n",
      "_________________________________________________________________\n",
      "conv_pw_8_bn (BatchNormaliza (None, 14, 14, 512)       2048      \n",
      "_________________________________________________________________\n",
      "conv_pw_8_relu (ReLU)        (None, 14, 14, 512)       0         \n",
      "_________________________________________________________________\n",
      "conv_dw_9 (DepthwiseConv2D)  (None, 14, 14, 512)       4608      \n",
      "_________________________________________________________________\n",
      "conv_dw_9_bn (BatchNormaliza (None, 14, 14, 512)       2048      \n",
      "_________________________________________________________________\n",
      "conv_dw_9_relu (ReLU)        (None, 14, 14, 512)       0         \n",
      "_________________________________________________________________\n",
      "conv_pw_9 (Conv2D)           (None, 14, 14, 512)       262144    \n",
      "_________________________________________________________________\n",
      "conv_pw_9_bn (BatchNormaliza (None, 14, 14, 512)       2048      \n",
      "_________________________________________________________________\n",
      "conv_pw_9_relu (ReLU)        (None, 14, 14, 512)       0         \n",
      "_________________________________________________________________\n",
      "conv_dw_10 (DepthwiseConv2D) (None, 14, 14, 512)       4608      \n",
      "_________________________________________________________________\n",
      "conv_dw_10_bn (BatchNormaliz (None, 14, 14, 512)       2048      \n",
      "_________________________________________________________________\n",
      "conv_dw_10_relu (ReLU)       (None, 14, 14, 512)       0         \n",
      "_________________________________________________________________\n",
      "conv_pw_10 (Conv2D)          (None, 14, 14, 512)       262144    \n",
      "_________________________________________________________________\n",
      "conv_pw_10_bn (BatchNormaliz (None, 14, 14, 512)       2048      \n",
      "_________________________________________________________________\n",
      "conv_pw_10_relu (ReLU)       (None, 14, 14, 512)       0         \n",
      "_________________________________________________________________\n",
      "conv_dw_11 (DepthwiseConv2D) (None, 14, 14, 512)       4608      \n",
      "_________________________________________________________________\n",
      "conv_dw_11_bn (BatchNormaliz (None, 14, 14, 512)       2048      \n",
      "_________________________________________________________________\n",
      "conv_dw_11_relu (ReLU)       (None, 14, 14, 512)       0         \n",
      "_________________________________________________________________\n",
      "conv_pw_11 (Conv2D)          (None, 14, 14, 512)       262144    \n",
      "_________________________________________________________________\n",
      "conv_pw_11_bn (BatchNormaliz (None, 14, 14, 512)       2048      \n",
      "_________________________________________________________________\n",
      "conv_pw_11_relu (ReLU)       (None, 14, 14, 512)       0         \n",
      "_________________________________________________________________\n",
      "conv_pad_12 (ZeroPadding2D)  (None, 15, 15, 512)       0         \n",
      "_________________________________________________________________\n",
      "conv_dw_12 (DepthwiseConv2D) (None, 7, 7, 512)         4608      \n",
      "_________________________________________________________________\n",
      "conv_dw_12_bn (BatchNormaliz (None, 7, 7, 512)         2048      \n",
      "_________________________________________________________________\n",
      "conv_dw_12_relu (ReLU)       (None, 7, 7, 512)         0         \n",
      "_________________________________________________________________\n",
      "conv_pw_12 (Conv2D)          (None, 7, 7, 1024)        524288    \n",
      "_________________________________________________________________\n",
      "conv_pw_12_bn (BatchNormaliz (None, 7, 7, 1024)        4096      \n",
      "_________________________________________________________________\n",
      "conv_pw_12_relu (ReLU)       (None, 7, 7, 1024)        0         \n",
      "_________________________________________________________________\n",
      "conv_dw_13 (DepthwiseConv2D) (None, 7, 7, 1024)        9216      \n",
      "_________________________________________________________________\n",
      "conv_dw_13_bn (BatchNormaliz (None, 7, 7, 1024)        4096      \n",
      "_________________________________________________________________\n",
      "conv_dw_13_relu (ReLU)       (None, 7, 7, 1024)        0         \n",
      "_________________________________________________________________\n",
      "conv_pw_13 (Conv2D)          (None, 7, 7, 1024)        1048576   \n",
      "_________________________________________________________________\n",
      "conv_pw_13_bn (BatchNormaliz (None, 7, 7, 1024)        4096      \n",
      "_________________________________________________________________\n",
      "conv_pw_13_relu (ReLU)       (None, 7, 7, 1024)        0         \n",
      "_________________________________________________________________\n",
      "global_average_pooling2d_4 ( (None, 1024)              0         \n",
      "_________________________________________________________________\n",
      "dense_13 (Dense)             (None, 1024)              1049600   \n",
      "_________________________________________________________________\n",
      "dense_14 (Dense)             (None, 1024)              1049600   \n",
      "_________________________________________________________________\n",
      "dense_15 (Dense)             (None, 512)               524800    \n",
      "_________________________________________________________________\n",
      "dense_16 (Dense)             (None, 10)                5130      \n",
      "=================================================================\n",
      "Total params: 5,857,994\n",
      "Trainable params: 2,629,130\n",
      "Non-trainable params: 3,228,864\n",
      "_________________________________________________________________\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "from keras.models import Sequential\n",
    "from keras.layers import Dense, Dropout, Activation, Flatten, GlobalAveragePooling2D\n",
    "from keras.layers import Conv2D, MaxPooling2D, ZeroPadding2D\n",
    "from keras.layers.normalization import BatchNormalization\n",
    "from keras.models import Model\n",
    "\n",
    "# The monkey breed dataset has 10 classes\n",
    "num_classes = 10\n",
    "\n",
    "# Build the new classification head on top of the MobileNet base\n",
    "FC_Head = addTopModelMobileNet(MobileNet, num_classes)\n",
    "\n",
    "# Join the base and the head into one model\n",
    "model = Model(inputs = MobileNet.input, outputs = FC_Head)\n",
    "\n",
    "# summary() prints the table itself and returns None, so wrapping it in\n",
    "# print() would only add a stray \"None\" line\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading our Monkey Breed Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 1097 images belonging to 10 classes.\n",
      "Found 272 images belonging to 10 classes.\n"
     ]
    }
   ],
   "source": [
    "from keras.preprocessing.image import ImageDataGenerator\n",
    "\n",
    "train_data_dir = './monkey_breed/train'\n",
    "validation_data_dir = './monkey_breed/validation'\n",
    "\n",
    "# Training images are rescaled by 1/255 and randomly augmented\n",
    "train_datagen = ImageDataGenerator(\n",
    "    rescale=1./255,\n",
    "    rotation_range=45,\n",
    "    width_shift_range=0.3,\n",
    "    height_shift_range=0.3,\n",
    "    horizontal_flip=True,\n",
    "    fill_mode='nearest',\n",
    ")\n",
    "\n",
    "# Validation images are only rescaled, with no augmentation\n",
    "validation_datagen = ImageDataGenerator(rescale=1./255)\n",
    "\n",
    "# Batch size shared by both generators (16-32 is typical on most mid tier systems)\n",
    "batch_size = 32\n",
    "\n",
    "train_generator = train_datagen.flow_from_directory(\n",
    "    train_data_dir,\n",
    "    target_size=(img_rows, img_cols),\n",
    "    batch_size=batch_size,\n",
    "    class_mode='categorical',\n",
    ")\n",
    "\n",
    "validation_generator = validation_datagen.flow_from_directory(\n",
    "    validation_data_dir,\n",
    "    target_size=(img_rows, img_cols),\n",
    "    batch_size=batch_size,\n",
    "    class_mode='categorical',\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training our Model\n",
    "- Note we're using checkpointing and early stopping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'model' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-34-93509b269ef0>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     19\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     20\u001b[0m \u001b[0;31m# We use a very small learning rate\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m model.compile(loss = 'categorical_crossentropy',\n\u001b[0m\u001b[1;32m     22\u001b[0m               \u001b[0moptimizer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mRMSprop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0.001\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     23\u001b[0m               metrics = ['accuracy'])\n",
      "\u001b[0;31mNameError\u001b[0m: name 'model' is not defined"
     ]
    }
   ],
   "source": [
    "from keras.optimizers import RMSprop\n",
    "from keras.callbacks import ModelCheckpoint, EarlyStopping\n",
    "\n",
    "                     \n",
    "# Save the model to disk each time the validation loss reaches a new minimum\n",
    "checkpoint = ModelCheckpoint(\"/home/deeplearningcv/DeepLearningCV/Trained Models/monkey_breed_mobileNet.h5\",\n",
    "                             monitor=\"val_loss\",\n",
    "                             mode=\"min\",\n",
    "                             save_best_only = True,\n",
    "                             verbose=1)\n",
    "\n",
    "# Stop once val_loss has not improved for 3 epochs in a row and\n",
    "# put the best weights seen so far back into the model\n",
    "earlystop = EarlyStopping(monitor = 'val_loss', \n",
    "                          min_delta = 0, \n",
    "                          patience = 3,\n",
    "                          verbose = 1,\n",
    "                          restore_best_weights = True)\n",
    "\n",
    "# we put our call backs into a callback list\n",
    "callbacks = [earlystop, checkpoint]\n",
    "\n",
    "# RMSprop with a learning rate of 0.001\n",
    "model.compile(loss = 'categorical_crossentropy',\n",
    "              optimizer = RMSprop(lr = 0.001),\n",
    "              metrics = ['accuracy'])\n",
    "\n",
    "# Read the sample counts and the batch size from the generators themselves.\n",
    "# The generators were created with their own batch size, so deriving the\n",
    "# step counts from a different hard-coded value would make every epoch\n",
    "# draw the wrong number of batches.\n",
    "nb_train_samples = train_generator.samples\n",
    "nb_validation_samples = validation_generator.samples\n",
    "batch_size = train_generator.batch_size\n",
    "\n",
    "# We only train 5 EPOCHS \n",
    "epochs = 5\n",
    "\n",
    "history = model.fit_generator(\n",
    "    train_generator,\n",
    "    steps_per_epoch = nb_train_samples // batch_size,\n",
    "    epochs = epochs,\n",
    "    callbacks = callbacks,\n",
    "    validation_data = validation_generator,\n",
    "    validation_steps = nb_validation_samples // validation_generator.batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading our classifier\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.models import load_model\n",
    "\n",
    "# Load the saved model from the same file path the training cell's ModelCheckpoint writes to\n",
    "classifier = load_model('/home/deeplearningcv/DeepLearningCV/Trained Models/monkey_breed_mobileNet.h5')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Testing our classifier on some test images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Class - mantled_howler \n",
      "Class - patas_monkey\n",
      "Class - patas_monkey\n",
      "Class - silvery_marmoset\n",
      "Class - black_headed_night_monkey\n",
      "Class - pygmy_marmoset \n",
      "Class - silvery_marmoset\n",
      "Class - mantled_howler \n",
      "Class - common_squirrel_monkey\n",
      "Class - patas_monkey\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import cv2\n",
    "import numpy as np\n",
    "from os import listdir\n",
    "from os.path import isfile, join\n",
    "\n",
    "# Breed names ordered by class index; the strings are reproduced exactly,\n",
    "# including the trailing space that two of them carry\n",
    "_breed_names = [\"mantled_howler \",\n",
    "                \"patas_monkey\",\n",
    "                \"bald_uakari\",\n",
    "                \"japanese_macaque\",\n",
    "                \"pygmy_marmoset \",\n",
    "                \"white_headed_capuchin\",\n",
    "                \"silvery_marmoset\",\n",
    "                \"common_squirrel_monkey\",\n",
    "                \"black_headed_night_monkey\",\n",
    "                \"nilgiri_langur\"]\n",
    "\n",
    "# Keyed by the string form of a one-element prediction array, e.g. \"[3]\"\n",
    "monkey_breeds_dict = {\"[{}]\".format(i): name for i, name in enumerate(_breed_names)}\n",
    "\n",
    "# Keyed by the dataset folder name, e.g. \"n3\"\n",
    "monkey_breeds_dict_n = {\"n{}\".format(i): name for i, name in enumerate(_breed_names)}\n",
    "\n",
    "def draw_test(name, pred, im):\n",
    "    \"\"\"Show ``im`` in an OpenCV window with the predicted breed written above it.\n",
    "\n",
    "    name: title of the cv2 window.\n",
    "    pred: predicted class index, either a plain integer or an array-like\n",
    "        whose first element is the index (e.g. np.argmax(..., axis=1)).\n",
    "    im: image array to display; the text is drawn on a bordered copy.\n",
    "    \"\"\"\n",
    "    # Normalise the prediction to a plain int so the lookup does not depend\n",
    "    # on how NumPy formats an array as a string\n",
    "    class_index = int(np.ravel(pred)[0])\n",
    "    monkey = monkey_breeds_dict[\"[{}]\".format(class_index)]\n",
    "    BLACK = [0,0,0]\n",
    "    # 80 px black strip on top (room for the label) and 100 px on the right\n",
    "    expanded_image = cv2.copyMakeBorder(im, 80, 0, 0, 100 ,cv2.BORDER_CONSTANT,value=BLACK)\n",
    "    cv2.putText(expanded_image, monkey, (20, 60) , cv2.FONT_HERSHEY_SIMPLEX,1, (0,0,255), 2)\n",
    "    cv2.imshow(name, expanded_image)\n",
    "\n",
    "def getRandomImage(path):\n",
    "    \"\"\"Return a random image from a random class folder inside ``path``.\n",
    "\n",
    "    Prints the breed name of the chosen folder (looked up in\n",
    "    ``monkey_breeds_dict_n``) and returns the result of ``cv2.imread`` for\n",
    "    the chosen file (None when OpenCV cannot read it). Raises ValueError\n",
    "    when ``path`` has no sub-folders or the chosen folder has no files.\n",
    "    \"\"\"\n",
    "    folders = [d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))]\n",
    "    if not folders:\n",
    "        raise ValueError(\"no class folders found in \" + str(path))\n",
    "    path_class = folders[np.random.randint(0, len(folders))]\n",
    "    print(\"Class - \" + monkey_breeds_dict_n[str(path_class)])\n",
    "    # os.path.join works whether or not `path` ends with a separator\n",
    "    file_path = os.path.join(path, path_class)\n",
    "    file_names = [f for f in listdir(file_path) if isfile(join(file_path, f))]\n",
    "    if not file_names:\n",
    "        raise ValueError(\"no files found in \" + file_path)\n",
    "    image_name = file_names[np.random.randint(0, len(file_names))]\n",
    "    return cv2.imread(join(file_path, image_name))\n",
    "\n",
    "for i in range(0,10):\n",
    "    input_im = getRandomImage(\"./monkey_breed/validation/\")\n",
    "    if input_im is None:\n",
    "        # cv2.imread gives None for a file it cannot read; skip it\n",
    "        print(\"Could not read the selected image, skipping\")\n",
    "        continue\n",
    "\n",
    "    # Half-size copy in OpenCV's BGR order, kept for display with cv2.imshow\n",
    "    input_original = cv2.resize(input_im, None, fx=0.5, fy=0.5, interpolation = cv2.INTER_LINEAR)\n",
    "\n",
    "    # cv2.imread returns BGR, but the network was trained on RGB images\n",
    "    # loaded by Keras, so swap the channel order before predicting\n",
    "    input_im = cv2.cvtColor(input_im, cv2.COLOR_BGR2RGB)\n",
    "    input_im = cv2.resize(input_im, (224, 224), interpolation = cv2.INTER_LINEAR)\n",
    "    input_im = input_im / 255.\n",
    "    input_im = input_im.reshape(1,224,224,3) \n",
    "\n",
    "    # Get Prediction\n",
    "    res = np.argmax(classifier.predict(input_im, 1, verbose = 0), axis=1)\n",
    "\n",
    "    # Show image with predicted class and wait for a key press\n",
    "    draw_test(\"Prediction\", res, input_original) \n",
    "    cv2.waitKey(0)\n",
    "\n",
    "cv2.destroyAllWindows()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
